{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MhoQ0WE77laV"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:27.275480Z",
     "iopub.status.busy": "2023-11-07T23:20:27.275070Z",
     "iopub.status.idle": "2023-11-07T23:20:27.279330Z",
     "shell.execute_reply": "2023-11-07T23:20:27.278476Z"
    },
    "id": "_ckMIh7O7s6D"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jYysdyb-CaWM"
   },
   "source": [
    "# 使用 tf.distribute.Strategy 进行自定义训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FbVhjPpzn6BM"
   },
   "source": [
    "本教程演示了如何使用具有自定义训练循环的 TensorFlow API `tf.distribute.Strategy`，它提供了一种用于在多个处理单元（GPU、多台机器或 TPU）之间[分配训练](../../guide/distributed_training.ipynb)的抽象。在此示例中，将在 [Fashion MNIST 数据集](https://github.com/zalandoresearch/fashion-mnist)上训练一个简单的卷积神经网络，此数据集包含 70,000 个大小为 28 x 28 的图像。\n",
    "\n",
    "[自定义训练循环](../customization/custom_training_walkthrough.ipynb)提供了灵活性并且能够更好地控制训练。此外，它们也让调试模型和训练循环更加容易。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:27.283520Z",
     "iopub.status.busy": "2023-11-07T23:20:27.283029Z",
     "iopub.status.idle": "2023-11-07T23:20:29.910971Z",
     "shell.execute_reply": "2023-11-07T23:20:29.909891Z"
    },
    "id": "dzLKpmZICaWN"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:20:27.755048: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2023-11-07 23:20:27.755105: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2023-11-07 23:20:27.756847: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.15.0-rc1\n"
     ]
    }
   ],
   "source": [
    "# Import TensorFlow\n",
    "import tensorflow as tf\n",
    "\n",
    "# Helper libraries\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MM6W__qraV55"
   },
   "source": [
    "## 下载 Fashion MNIST 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:29.915136Z",
     "iopub.status.busy": "2023-11-07T23:20:29.914619Z",
     "iopub.status.idle": "2023-11-07T23:20:30.458794Z",
     "shell.execute_reply": "2023-11-07T23:20:30.457651Z"
    },
    "id": "7MqDQO0KCaWS"
   },
   "outputs": [],
   "source": [
    "fashion_mnist = tf.keras.datasets.fashion_mnist\n",
    "\n",
    "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
    "\n",
    "# Add a dimension to the array -> new shape == (28, 28, 1)\n",
    "# This is done because the first layer in our model is a convolutional\n",
    "# layer and it requires a 4D input (batch_size, height, width, channels).\n",
    "# batch_size dimension will be added later on.\n",
    "train_images = train_images[..., None]\n",
    "test_images = test_images[..., None]\n",
    "\n",
    "# Scale the images to the [0, 1] range.\n",
    "train_images = train_images / np.float32(255)\n",
    "test_images = test_images / np.float32(255)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4AXoHhrsbdF3"
   },
   "source": [
    "## 创建一个分发变量和图形的策略"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5mVuLZhbem8d"
   },
   "source": [
    "`tf.distribute.MirroredStrategy` 策略是如何运作的？\n",
    "\n",
    "- 所有变量和模型计算图都会在副本之间复制。\n",
    "- 输入都均匀分布在副本中。\n",
    "- 每个副本在收到输入后计算输入的损失和梯度。\n",
    "- 通过求和，每一个副本上的梯度都能同步。\n",
    "- 同步后，每个副本上的复制的变量都可以同样更新。\n",
    "\n",
    "注：您可以将下面的所有代码放在单个作用域内。出于说明目的，本示例将它分为几个代码单元。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:30.463591Z",
     "iopub.status.busy": "2023-11-07T23:20:30.463288Z",
     "iopub.status.idle": "2023-11-07T23:20:33.757568Z",
     "shell.execute_reply": "2023-11-07T23:20:33.756686Z"
    },
    "id": "F2VeZUWUj5S4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    }
   ],
   "source": [
    "# If the list of devices is not specified in\n",
    "# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.\n",
    "strategy = tf.distribute.MirroredStrategy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:33.761962Z",
     "iopub.status.busy": "2023-11-07T23:20:33.761634Z",
     "iopub.status.idle": "2023-11-07T23:20:33.766079Z",
     "shell.execute_reply": "2023-11-07T23:20:33.765396Z"
    },
    "id": "ZngeM_2o0_JO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of devices: 4\n"
     ]
    }
   ],
   "source": [
    "print('Number of devices: {}'.format(strategy.num_replicas_in_sync))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "k53F5I_IiGyI"
   },
   "source": [
    "## 设置输入流水线"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:33.769604Z",
     "iopub.status.busy": "2023-11-07T23:20:33.769316Z",
     "iopub.status.idle": "2023-11-07T23:20:33.773113Z",
     "shell.execute_reply": "2023-11-07T23:20:33.772417Z"
    },
    "id": "jwJtsCQhHK-E"
   },
   "outputs": [],
   "source": [
    "BUFFER_SIZE = len(train_images)\n",
    "\n",
    "BATCH_SIZE_PER_REPLICA = 64\n",
    "GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync\n",
    "\n",
    "EPOCHS = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "J7fj3GskHC8g"
   },
   "source": [
    "创建数据集并分发它们："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:33.776346Z",
     "iopub.status.busy": "2023-11-07T23:20:33.776062Z",
     "iopub.status.idle": "2023-11-07T23:20:34.580169Z",
     "shell.execute_reply": "2023-11-07T23:20:34.579308Z"
    },
    "id": "WYrMNNDhAvVl"
   },
   "outputs": [],
   "source": [
    "train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) \n",
    "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) \n",
    "\n",
    "train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)\n",
    "test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bAXAo_wWbWSb"
   },
   "source": [
    "## 创建模型\n",
    "\n",
    "使用 `tf.keras.Sequential` 创建模型。也可以使用[模型子类化 API](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) 或[函数式 API](https://tensorflow.google.cn/guide/keras/functional) 来完成此操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:34.584630Z",
     "iopub.status.busy": "2023-11-07T23:20:34.584325Z",
     "iopub.status.idle": "2023-11-07T23:20:34.589279Z",
     "shell.execute_reply": "2023-11-07T23:20:34.588575Z"
    },
    "id": "9ODch-OFCaW4"
   },
   "outputs": [],
   "source": [
    "def create_model():\n",
    "  model = tf.keras.Sequential([\n",
    "      tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
    "      tf.keras.layers.MaxPooling2D(),\n",
    "      tf.keras.layers.Conv2D(64, 3, activation='relu'),\n",
    "      tf.keras.layers.MaxPooling2D(),\n",
    "      tf.keras.layers.Flatten(),\n",
    "      tf.keras.layers.Dense(64, activation='relu'),\n",
    "      tf.keras.layers.Dense(10)\n",
    "    ])\n",
    "\n",
    "  return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:34.592538Z",
     "iopub.status.busy": "2023-11-07T23:20:34.592261Z",
     "iopub.status.idle": "2023-11-07T23:20:34.595992Z",
     "shell.execute_reply": "2023-11-07T23:20:34.595314Z"
    },
    "id": "9iagoTBfijUz"
   },
   "outputs": [],
   "source": [
    "# Create a checkpoint directory to store the checkpoints.\n",
    "checkpoint_dir = './training_checkpoints'\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0-VVTqDEICrl"
   },
   "source": [
    "## 定义损失函数\n",
    "\n",
    "通常，在具有单个 GPU/CPU 的单台机器上，损失函数会除以输入批次中的样本数。\n",
    "\n",
    "*因此，使用 `tf.distribute.Strategy` 时应如何计算损失？*\n",
    "\n",
    "- 例如，假设有 4 个 GPU，批量大小为 64。一个批次的输入会分布在各个副本（4 个 GPU）上，每个副本获得一个大小为 16 的输入。\n",
    "\n",
    "- 每个副本上的模型都会使用其各自的输入进行前向传递，并计算损失。现在，不将损失除以其相应输入中的样本数 (BATCH_SIZE_PER_REPLICA = 16)，而应将损失除以 GLOBAL_BATCH_SIZE (64)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OCIcsaeoIHJX"
   },
   "source": [
    "*为什么这样做？*\n",
    "\n",
    "- 之所以需要这样做，是因为在每个副本上计算完梯度后，会通过对梯度**求和**在副本之间同步梯度。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e-wlFFZbP33n"
   },
   "source": [
    "*如何在 TensorFlow 中执行此操作？*\n",
    "\n",
    "- 如果您正在编写自定义训练循环（如本教程中所述），则应将每个样本的损失相加，然后将总和除以 GLOBAL_BATCH_SIZE: `scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`，或者您可以使用 `tf.nn.compute_average_loss`，它会将每个样本的损失、可选样本权重和 GLOBAL_BATCH_SIZE 作为参数，并返回经过缩放的损失。\n",
    "\n",
    "- 如果在模型中使用正则化损失，则需要按副本数缩放损失值。可以使用 `tf.nn.scale_regularization_loss` 函数进行此操作。\n",
    "\n",
    "- 不建议使用 `tf.reduce_mean`。这样做会将损失除以实际的每个副本批次大小，该大小可能会随着步骤的不同而发生变化。\n",
    "\n",
    "- 这种归约和缩放会在 Keras `Model.compile` 和 `Model.fit` 中自动完成。\n",
    "\n",
    "- 如果使用 `tf.keras.losses` 类（如下面的示例所示），则需要将损失归约显式指定为 `NONE` 或 `SUM`。与 `tf.distribute.Strategy` 一起使用时，不允许使用 `AUTO` 和 `SUM_OVER_BATCH_SIZE`。不允许使用 `AUTO`，因为用户应明确考虑他们想要的归约量，以确保在分布式情况下归约量正确。不允许使用 `SUM_OVER_BATCH_SIZE`，因为当前它只能按副本批次大小进行划分，而将按副本数量划分留给用户，这可能很容易遗漏。因此，您需要自己显式执行归约操作。\n",
    "\n",
    "- 如果 `labels` 为多维，则对每个样本中的元素数量的 `per_example_loss` 求平均值。例如，如果 `predictions` 的形状为 `(batch_size, H, W, n_classes)`，而 `labels` 为 `(batch_size, H, W)`，则需要更新 `per_example_loss`，例如：`per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)`\n",
    "\n",
    "    小心：**验证损失的形状**。`tf.losses`/`tf.keras.losses` 中的损失函数通常会返回输入最后一个维度的平均值。损失类封装这些函数。在创建损失类的实例时传递 `reduction=Reduction.NONE`，表示“无**额外**缩减”。对于样本输入形状为 `[batch, W, H, n_classes]` 的类别损失，会缩减 `n_classes` 维度。对于类似 `losses.mean_squared_error` 或 `losses.binary_crossentropy` 的逐点损失，应包含一个虚拟轴，使 `[batch, W, H, 1]` 缩减为 `[batch, W, H]`。如果没有虚拟轴，`则 [batch, W, H]` 将被错误地缩减为 `[batch, W]`。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:34.599679Z",
     "iopub.status.busy": "2023-11-07T23:20:34.599403Z",
     "iopub.status.idle": "2023-11-07T23:20:34.604169Z",
     "shell.execute_reply": "2023-11-07T23:20:34.603476Z"
    },
    "id": "R144Wci782ix"
   },
   "outputs": [],
   "source": [
    "with strategy.scope():\n",
    "  # Set reduction to `NONE` so you can do the reduction afterwards and divide by\n",
    "  # global batch size.\n",
    "  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "      from_logits=True,\n",
    "      reduction=tf.keras.losses.Reduction.NONE)\n",
    "  def compute_loss(labels, predictions):\n",
    "    per_example_loss = loss_object(labels, predictions)\n",
    "    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w8y54-o9T2Ni"
   },
   "source": [
    "## 定义衡量指标以跟踪损失和准确性\n",
    "\n",
    "这些指标可以跟踪测试的损失，训练和测试的准确性。 您可以使用`.result()`随时获取累积的统计信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:34.607480Z",
     "iopub.status.busy": "2023-11-07T23:20:34.607188Z",
     "iopub.status.idle": "2023-11-07T23:20:34.651807Z",
     "shell.execute_reply": "2023-11-07T23:20:34.651045Z"
    },
    "id": "zt3AHb46Tr3w"
   },
   "outputs": [],
   "source": [
    "with strategy.scope():\n",
    "  test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "\n",
    "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "      name='train_accuracy')\n",
    "  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "      name='test_accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iuKuNXPORfqJ"
   },
   "source": [
    "## 训练循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:34.655720Z",
     "iopub.status.busy": "2023-11-07T23:20:34.655415Z",
     "iopub.status.idle": "2023-11-07T23:20:34.701685Z",
     "shell.execute_reply": "2023-11-07T23:20:34.700917Z"
    },
    "id": "OrMmakq5EqeQ"
   },
   "outputs": [],
   "source": [
    "# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.\n",
    "with strategy.scope():\n",
    "  model = create_model()\n",
    "\n",
    "  optimizer = tf.keras.optimizers.Adam()\n",
    "\n",
    "  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:34.705711Z",
     "iopub.status.busy": "2023-11-07T23:20:34.705422Z",
     "iopub.status.idle": "2023-11-07T23:20:34.710858Z",
     "shell.execute_reply": "2023-11-07T23:20:34.710277Z"
    },
    "id": "3UX43wUu04EL"
   },
   "outputs": [],
   "source": [
    "def train_step(inputs):\n",
    "  images, labels = inputs\n",
    "\n",
    "  with tf.GradientTape() as tape:\n",
    "    predictions = model(images, training=True)\n",
    "    loss = compute_loss(labels, predictions)\n",
    "\n",
    "  gradients = tape.gradient(loss, model.trainable_variables)\n",
    "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "  train_accuracy.update_state(labels, predictions)\n",
    "  return loss \n",
    "\n",
    "def test_step(inputs):\n",
    "  images, labels = inputs\n",
    "\n",
    "  predictions = model(images, training=False)\n",
    "  t_loss = loss_object(labels, predictions)\n",
    "\n",
    "  test_loss.update_state(t_loss)\n",
    "  test_accuracy.update_state(labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:20:34.714106Z",
     "iopub.status.busy": "2023-11-07T23:20:34.713811Z",
     "iopub.status.idle": "2023-11-07T23:21:26.467217Z",
     "shell.execute_reply": "2023-11-07T23:21:26.466297Z"
    },
    "id": "gX975dMSNw0e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1699399241.180172  527756 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.6526154279708862, Accuracy: 76.88500213623047, Test Loss: 0.463595449924469, Test Accuracy: 83.1300048828125\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2, Loss: 0.4034326672554016, Accuracy: 85.66166687011719, Test Loss: 0.3861437141895294, Test Accuracy: 86.69999694824219\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3, Loss: 0.3507663607597351, Accuracy: 87.48999786376953, Test Loss: 0.3586585521697998, Test Accuracy: 87.69999694824219\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4, Loss: 0.31498587131500244, Accuracy: 88.6683349609375, Test Loss: 0.34132587909698486, Test Accuracy: 87.88999938964844\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5, Loss: 0.29295358061790466, Accuracy: 89.3949966430664, Test Loss: 0.32417431473731995, Test Accuracy: 88.3499984741211\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6, Loss: 0.2780560553073883, Accuracy: 89.86499786376953, Test Loss: 0.3235797882080078, Test Accuracy: 88.6300048828125\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7, Loss: 0.26235538721084595, Accuracy: 90.52333068847656, Test Loss: 0.2898319959640503, Test Accuracy: 89.52000427246094\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8, Loss: 0.25144392251968384, Accuracy: 90.84666442871094, Test Loss: 0.295357882976532, Test Accuracy: 89.11000061035156\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9, Loss: 0.23802918195724487, Accuracy: 91.27166748046875, Test Loss: 0.2847434878349304, Test Accuracy: 89.9800033569336\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.22753538191318512, Accuracy: 91.67832946777344, Test Loss: 0.2718784213066101, Test Accuracy: 90.09000396728516\n"
     ]
    }
   ],
   "source": [
    "# `run` replicates the provided computation and runs it\n",
    "# with the distributed input.\n",
    "@tf.function\n",
    "def distributed_train_step(dataset_inputs):\n",
    "  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))\n",
    "  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,\n",
    "                         axis=None)\n",
    "\n",
    "@tf.function\n",
    "def distributed_test_step(dataset_inputs):\n",
    "  return strategy.run(test_step, args=(dataset_inputs,))\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "  # TRAIN LOOP\n",
    "  total_loss = 0.0\n",
    "  num_batches = 0\n",
    "  for x in train_dist_dataset:\n",
    "    total_loss += distributed_train_step(x)\n",
    "    num_batches += 1\n",
    "  train_loss = total_loss / num_batches\n",
    "\n",
    "  # TEST LOOP\n",
    "  for x in test_dist_dataset:\n",
    "    distributed_test_step(x)\n",
    "\n",
    "  if epoch % 2 == 0:\n",
    "    checkpoint.save(checkpoint_prefix)\n",
    "\n",
    "  template = (\"Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, \"\n",
    "              \"Test Accuracy: {}\")\n",
    "  print(template.format(epoch + 1, train_loss,\n",
    "                         train_accuracy.result() * 100, test_loss.result(),\n",
    "                         test_accuracy.result() * 100))\n",
    "\n",
    "  test_loss.reset_states()\n",
    "  train_accuracy.reset_states()\n",
    "  test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z1YvXqOpwy08"
   },
   "source": [
    "以上示例中需要注意的事项：\n",
    "\n",
    "- 使用 `for x in ...` 构造来迭代 `train_dist_dataset` 和 `test_dist_dataset`。\n",
    "- 缩放损失是`distributed_train_step`的返回值。 这个值会在各个副本使用`tf.distribute.Strategy.reduce`的时候合并，然后通过`tf.distribute.Strategy.reduce`叠加各个返回值来跨批次。\n",
    "- `tf.keras.Metrics` 应该在由 `tf.distribute.Strategy.run` 执行的 `train_step` 和 `test_step` 内更新。\n",
    "- `tf.distribute.Strategy.run` 会从策略中的每个本地副本返回结果，您可以通过多种方式使用此结果。可以对它们执行 `tf.distribute.Strategy.reduce` 以获得聚合值。还可以通过执行 `tf.distribute.Strategy.experimental_local_results` 获得包含在结果中的值的列表，每个本地副本一个列表。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-q5qp31IQD8t"
   },
   "source": [
    "## 恢复最新的检查点并进行测试"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WNW2P00bkMGJ"
   },
   "source": [
    "使用 `tf.distribute.Strategy` 设置了检查点的模型可以使用或不使用策略进行恢复。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:21:26.472055Z",
     "iopub.status.busy": "2023-11-07T23:21:26.471231Z",
     "iopub.status.idle": "2023-11-07T23:21:26.537904Z",
     "shell.execute_reply": "2023-11-07T23:21:26.536850Z"
    },
    "id": "pg3B-Cw_cn3a"
   },
   "outputs": [],
   "source": [
    "eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "      name='eval_accuracy')\n",
    "\n",
    "new_model = create_model()\n",
    "new_optimizer = tf.keras.optimizers.Adam()\n",
    "\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:21:26.542694Z",
     "iopub.status.busy": "2023-11-07T23:21:26.541968Z",
     "iopub.status.idle": "2023-11-07T23:21:26.546618Z",
     "shell.execute_reply": "2023-11-07T23:21:26.545904Z"
    },
    "id": "7qYii7KUYiSM"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def eval_step(images, labels):\n",
    "  predictions = new_model(images, training=False)\n",
    "  eval_accuracy(labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:21:26.550136Z",
     "iopub.status.busy": "2023-11-07T23:21:26.549634Z",
     "iopub.status.idle": "2023-11-07T23:21:27.259427Z",
     "shell.execute_reply": "2023-11-07T23:21:27.258508Z"
    },
    "id": "LeZ6eeWRoUNq"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy after restoring the saved model without strategy: 89.9800033569336\n"
     ]
    }
   ],
   "source": [
    "checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)\n",
    "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))\n",
    "\n",
    "for images, labels in test_dataset:\n",
    "  eval_step(images, labels)\n",
    "\n",
    "print('Accuracy after restoring the saved model without strategy: {}'.format(\n",
    "    eval_accuracy.result() * 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EbcI87EEzhzg"
   },
   "source": [
    "## 迭代一个数据集的替代方法\n",
    "\n",
    "### 使用迭代器\n",
    "\n",
    "如果要迭代给定的步数而不是遍历整个数据集，可以使用 `iter` 调用创建一个迭代器，并在该迭代器上显式地调用 `next`。您可以选择在 `tf.function` 内部和外部迭代数据集。下面是一个小代码段，演示了使用迭代器在 `tf.function` 外部迭代数据集。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:21:27.263330Z",
     "iopub.status.busy": "2023-11-07T23:21:27.263046Z",
     "iopub.status.idle": "2023-11-07T23:21:34.104233Z",
     "shell.execute_reply": "2023-11-07T23:21:34.103427Z"
    },
    "id": "7c73wGC00CzN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.20245365798473358, Accuracy: 92.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.22390851378440857, Accuracy: 91.796875\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.2280225306749344, Accuracy: 91.5234375\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.2149706780910492, Accuracy: 92.265625\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.19533059000968933, Accuracy: 92.96875\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.20433831214904785, Accuracy: 92.421875\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.20081932842731476, Accuracy: 92.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.22768919169902802, Accuracy: 91.484375\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.24364034831523895, Accuracy: 91.09375\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.20831342041492462, Accuracy: 92.96875\n"
     ]
    }
   ],
   "source": [
    "for _ in range(EPOCHS):\n",
    "  total_loss = 0.0\n",
    "  num_batches = 0\n",
    "  train_iter = iter(train_dist_dataset)\n",
    "\n",
    "  for _ in range(10):\n",
    "    total_loss += distributed_train_step(next(train_iter))\n",
    "    num_batches += 1\n",
    "  average_train_loss = total_loss / num_batches\n",
    "\n",
    "  template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
    "  print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))\n",
    "  train_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GxVp48Oy0m6y"
   },
   "source": [
    "### 在 tf.function 中迭代\n",
    "\n",
    "您还可以使用 `for x in ...` 构造在 `tf.function` 内部迭代整个输入 `train_dist_dataset`，或者像上面那样创建迭代器。下面的示例演示了使用 `@tf.function` 装饰器封装一个训练周期并在函数内部迭代 `train_dist_dataset`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:21:34.108371Z",
     "iopub.status.busy": "2023-11-07T23:21:34.107642Z",
     "iopub.status.idle": "2023-11-07T23:21:48.307028Z",
     "shell.execute_reply": "2023-11-07T23:21:48.306178Z"
    },
    "id": "-REzmcXv00qm"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:462: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.\n",
      "  warnings.warn(\"To make it possible to preserve tf.data options across \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.21631334722042084, Accuracy: 91.98333740234375\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2, Loss: 0.20537973940372467, Accuracy: 92.4749984741211\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3, Loss: 0.19430002570152283, Accuracy: 92.9050064086914\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4, Loss: 0.18660134077072144, Accuracy: 93.16999816894531\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5, Loss: 0.1797751635313034, Accuracy: 93.44666290283203\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6, Loss: 0.1693052053451538, Accuracy: 93.79000091552734\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7, Loss: 0.16252641379833221, Accuracy: 94.12333679199219\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8, Loss: 0.15544697642326355, Accuracy: 94.31666564941406\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9, Loss: 0.149961918592453, Accuracy: 94.52833557128906\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Loss: 0.13953761756420135, Accuracy: 94.99166870117188\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def distributed_train_epoch(dataset):\n",
    "  total_loss = 0.0\n",
    "  num_batches = 0\n",
    "  for x in dataset:\n",
    "    per_replica_losses = strategy.run(train_step, args=(x,))\n",
    "    total_loss += strategy.reduce(\n",
    "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
    "    num_batches += 1\n",
    "  return total_loss / tf.cast(num_batches, dtype=tf.float32)\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "  train_loss = distributed_train_epoch(train_dist_dataset)\n",
    "\n",
    "  template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
    "  print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))\n",
    "\n",
    "  train_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MuZGXiyC7ABR"
   },
   "source": [
    "### 跟踪副本中的训练的损失\n",
    "\n",
    "注意：作为通用的规则，您应该使用`tf.keras.Metrics`来跟踪每个样本的值以避免它们在副本中合并。\n",
    "\n",
    "由于执行的损失缩放计算，不建议使用 `tf.keras.metrics.Mean` 来跟踪不同副本的训练损失。\n",
    "\n",
    "例如，如果您运行具有以下特点的训练作业：\n",
    "\n",
    "- 两个副本\n",
    "- 在每个副本上处理两个例子\n",
    "- 产生的损失值：每个副本为[2,3]和[4,5]\n",
    "- 全局批次大小 = 4\n",
    "\n",
    "通过损失缩放，您可以通过添加损失值来计算每个副本上的每个样本的损失值，然后除以全局批量大小。 在这种情况下：`（2 + 3）/ 4 = 1.25`和`（4 + 5）/ 4 = 2.25`。\n",
    "\n",
    "如果使用 `tf.keras.metrics.Mean` 来跟踪两个副本的损失，结果会有所不同。在此示例中，您最终会得到一个 `total` 为 3.50 和 `count` 为 2 的结果，在指标上调用 `result()` 时，您将得到 `total`/`count` = 1.75。使用 `tf.keras.Metrics` 计算的损失将按等于同步副本数的附加因子进行缩放。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xisYJaV9KZTN"
   },
   "source": [
    "### 例子和教程\n",
    "\n",
    "以下是一些使用自定义训练循环来分发策略的示例：\n",
    "\n",
    "1. 分布式训练指南\n",
    "2. [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) 使用 `MirroredStrategy`的例子。\n",
    "3. 使用 <code>MirroredStrategy</code> 和 `TPUStrategy` 训练的 <a>BERT</a> 示例。此示例对于理解如何在分布式训练等过程中从检查点加载并生成定期检查点特别有帮助。\n",
    "4. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) 使用 `MirroredStrategy` 来启用 `keras_use_ctl` 标记。\n",
    "5. [NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) 使用 `MirroredStrategy`来训练的例子。\n",
    "\n",
    "可以在[分布策略指南](../../guide/distributed_training.ipynb)的*示例和教程*下找到更多示例。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6hEJNsokjOKs"
   },
   "source": [
    "## 下一步\n",
    "\n",
    "- 在您的模型上尝试新的 `tf.distribute.Strategy` API。\n",
    "- 访问[使用 `tf.function` 和 TensorFlow Profiler 提升性能](../../guide/function.ipynb)指南，详细了解优化 TensorFlow 模型性能的工具。\n",
    "- 查看 [TensorFlow 中的分布式训练](../../guide/distributed_training.ipynb)指南，其中提供了可用分布策略的概述。"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "custom_training.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
